import datetime
import gc
from concurrent.futures import ThreadPoolExecutor
from functools import cached_property, partial
from time import sleep
from typing import Any, Callable, Generator, Iterable, cast

from datasets.fingerprint import Hasher
from filelock import FileLock
from openai import NotFoundError
from sqlitedict import SqliteDict

from .. import DataDreamer, __version__
from ..utils.fs_utils import safe_fn
from .llm import DEFAULT_BATCH_SIZE, _check_max_new_tokens_possible
from .openai import OpenAI, OpenAIException


class OpenAIAssistant(OpenAI):
    """An LLM backed by the OpenAI Assistants (beta) API.

    Each prompt is sent as the sole user message of a fresh thread, which is run
    against an assistant configured with ``system_prompt`` (as its instructions)
    and ``tools``. The assistant's replies on that thread are joined and returned
    as the generation. The thread is deleted afterwards.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: None | str = None,
        tools: None | list[dict] = None,
        organization: None | str = None,
        api_key: None | str = None,
        base_url: None | str = None,
        api_version: None | str = None,
        retry_on_fail: bool = True,
        cache_folder_path: None | str = None,
        **kwargs,
    ):
        """Create the assistant wrapper.

        Args:
            model_name: The OpenAI model the assistant is created with.
            system_prompt: Used as the assistant's ``instructions``.
            tools: Tool definitions passed to the assistant unchanged; ``None``
                means no tools.
            organization, api_key, base_url, api_version, retry_on_fail,
            cache_folder_path, **kwargs: Forwarded to the ``OpenAI`` base class.
        """
        super().__init__(
            model_name=model_name,
            system_prompt=system_prompt,
            organization=organization,
            api_key=api_key,
            base_url=base_url,
            api_version=api_version,
            retry_on_fail=retry_on_fail,
            cache_folder_path=cache_folder_path,
            **kwargs,
        )
        self.tools = tools or []

    @cached_property
    def assistant_id(self) -> str:
        """The ID of the OpenAI assistant used for generation.

        If the local cache holds an assistant ID that still exists on the server,
        it is reused. Otherwise a new assistant is created. The resulting ID is
        written back to the cache.

        When a DataDreamer session is active, a cleanup function that deletes the
        assistant is registered. A cached ID may therefore point to a deleted
        assistant in a later session; that case falls through to creation via
        ``NotFoundError``.
        """
        from openai.types.beta import Assistant

        assistant_id: None | str = None
        assistant: None | Assistant = None
        if self.cache_and_lock:
            cache, lock = cast(tuple[SqliteDict, FileLock], self.cache_and_lock)
            assistant_id = cache.get("assistant_id", None)
        if assistant_id is not None:  # pragma: no cover
            try:
                assistant = self.client.beta.assistants.retrieve(
                    assistant_id=assistant_id
                )
            except NotFoundError:
                # The cached assistant was deleted server-side; create a new one.
                pass
        if assistant is None:
            date = datetime.datetime.now()
            date_str = date.strftime("%b %d, %Y %I:%M %p")
            assistant = self.client.beta.assistants.create(
                model=self.model_name,
                description="Automatically generated by DataDreamer.",
                instructions=self.system_prompt,
                metadata={
                    "datadreamer_version": __version__,
                    "version": str(self.version),
                    "_cache_name": self._cache_name,
                },
                name=f"DataDreamer Assistant (created on {date_str})",
                tools=self.tools,  # type: ignore[arg-type]
            )
        assert assistant is not None
        if self.cache_and_lock:
            with lock:
                cache["assistant_id"] = assistant.id
                cache.commit()
        if DataDreamer.initialized():
            DataDreamer._add_cleanup_func(
                partial(
                    lambda self, assistant_id: self.client.beta.assistants.delete(
                        assistant_id=assistant_id
                    ),
                    self,
                    assistant.id,
                )
            )
        return assistant.id

    def _run_batch(  # type:ignore[override]
        self,
        max_length_func: Callable[[list[str]], int],
        system_prompt: str,
        inputs: list[str],
        max_new_tokens: None | int = None,
        temperature: float = 1.0,
        top_p: float = 0.0,
        n: int = 1,
        stop: None | str | list[str] = None,
        repetition_penalty: None | float = None,
        logit_bias: None | dict[int, float] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        seed: None | int = None,
        **kwargs,
    ) -> list[str] | list[list[str]]:
        """Generate a completion for each prompt in ``inputs`` via the assistant.

        Prompts are processed concurrently on a thread pool of ``batch_size``
        workers. Sampling arguments (temperature, top_p, stop, ...) are accepted
        for interface compatibility but are not sent to the Assistants API.

        Returns:
            One string per prompt when ``n == 1``. Otherwise, one single-element
            list per prompt.

        Raises:
            OpenAIException: If a run ends as ``cancelled``, ``failed``, or
                ``expired``.
            Exception: If a run ends as ``requires_action``.
        """
        prompts = inputs

        # Validate max_new_tokens against the prompts (the result is not used
        # further because the Assistants API call below takes no such argument)
        _check_max_new_tokens_possible(
            self=self,
            max_length_func=max_length_func,
            prompts=prompts,
            max_new_tokens=max_new_tokens,
        )

        if not prompts:
            return []

        # Resolve the assistant once, on this thread, before fanning out. Since
        # Python 3.12 ``functools.cached_property`` does not lock, so letting each
        # worker hit ``self.assistant_id`` for the first time concurrently could
        # create (and orphan) several assistants.
        assistant_id = self.assistant_id

        terminal_statuses = [
            "completed",
            "requires_action",
            "cancelled",
            "failed",
            "expired",
        ]

        # Run the model
        def get_generated_texts(self, kwargs, prompt) -> list[str]:
            thread = self.client.beta.threads.create(
                messages=[{"role": "user", "content": prompt}]
            )
            succeeded = False
            try:
                run = self.client.beta.threads.runs.create(
                    thread_id=thread.id, assistant_id=assistant_id
                )
                # Poll until the run reaches a terminal (or action-required) state
                while run.status not in terminal_statuses:
                    run = self.client.beta.threads.runs.retrieve(
                        thread_id=thread.id, run_id=run.id
                    )
                    sleep(0.5)
                if run.status != "completed":  # pragma: no cover
                    if run.status == "requires_action":
                        # NOTE(review): presumably a plain Exception so that
                        # OpenAIException-based retry logic does not retry it
                        # — confirm against the base class.
                        raise Exception(
                            f"OpenAI Assistant did not complete with status: {run.status}"
                        )
                    else:
                        raise OpenAIException(
                            f"OpenAI Assistant did not complete with status: {run.status}"
                        )
                thread_messages = self.client.beta.threads.messages.list(thread.id)
                succeeded = True
            finally:
                # Always delete the thread so failed runs do not leak it. A
                # cleanup failure must not mask the error that aborted the run.
                try:
                    self.client.beta.threads.delete(thread_id=thread.id)
                except Exception:  # pragma: no cover
                    if succeeded:
                        raise
            return [
                "\n\n".join(
                    [
                        m.content[0].text.value.strip()
                        for m in sorted(
                            thread_messages.data, key=lambda m: m.created_at
                        )
                        if m.role == "assistant"
                    ]
                )
            ]

        if batch_size not in self.executor_pools:
            self.executor_pools[batch_size] = ThreadPoolExecutor(max_workers=batch_size)
        generated_texts_batch = list(
            self.executor_pools[batch_size].map(
                partial(get_generated_texts, self, kwargs), prompts
            )
        )
        if n == 1:
            return [batch[0] for batch in generated_texts_batch]
        else:  # pragma: no cover
            return generated_texts_batch

    def run(  # type:ignore[override]
        self,
        prompts: Iterable[str],
        batch_size: int = DEFAULT_BATCH_SIZE,
        batch_scheduler_buffer_size: None | int = None,
        adaptive_batch_size: bool = False,
        progress_interval: None | int = 60,
        force: bool = False,
        cache_only: bool = False,
        verbose: None | bool = None,
        log_level: None | int = None,
        total_num_prompts: None | int = None,
        return_generator: bool = False,
        **kwargs,
    ) -> Generator[str | list[str], None, None] | list[str | list[str]]:
        """Generate one completion per prompt.

        The sampling arguments of the base ``run`` are not exposed here and are
        pinned to fixed values (``n=1``, ``temperature=1.0``, ``top_p=0.0``, no
        stop/penalty/bias/seed), because ``_run_batch`` does not forward them to
        the Assistants API. All other arguments are passed to the base class.
        """
        return super().run(
            prompts=prompts,
            max_new_tokens=None,
            temperature=1.0,
            top_p=0.0,
            n=1,
            stop=None,
            repetition_penalty=None,
            logit_bias=None,
            batch_size=batch_size,
            batch_scheduler_buffer_size=batch_scheduler_buffer_size,
            adaptive_batch_size=adaptive_batch_size,
            seed=None,
            progress_interval=progress_interval,
            force=force,
            cache_only=cache_only,
            verbose=verbose,
            log_level=log_level,
            total_num_prompts=total_num_prompts,
            return_generator=return_generator,
            **kwargs,
        )

    @cached_property
    def _cache_name(self) -> None | str:
        """Cache name: the model name plus a hash of system prompt and tools.

        Tools are sorted by their own hash first, so their order does not change
        the resulting name.
        """
        names = [safe_fn(self.model_name, allow_slashes=False)]
        to_hash: list[Any] = []
        to_hash.append(self.system_prompt)
        to_hash.append(sorted(self.tools, key=lambda t: Hasher.hash(t)))
        names.append(Hasher.hash(to_hash))
        return "_".join(names)

    def unload_model(self):
        """Unload via the base class and forget the memoized ``assistant_id``.

        The next access to ``assistant_id`` re-resolves it, either from the
        cache or by creating a new assistant.
        """
        super().unload_model()

        # Delete cached assistant_id
        if "assistant_id" in self.__dict__:
            del self.__dict__["assistant_id"]

        # Garbage collect
        gc.collect()


# Public API of this module: only the assistant wrapper is exported.
__all__ = ["OpenAIAssistant"]
